{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2934706f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | default_exp _components.task_streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66b86144",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export \n",
    "\n",
    "import asyncio\n",
    "import sys\n",
    "from abc import ABC, abstractmethod\n",
    "\n",
    "from asyncio import Task\n",
    "from typing import *\n",
    "\n",
    "import anyio\n",
    "from aiokafka import ConsumerRecord\n",
    "\n",
    "from logging import Logger\n",
    "from fastkafka._components.logger import get_logger\n",
    "from fastkafka._components.meta import export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "777beaa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime, timedelta\n",
    "\n",
    "from anyio import create_task_group, create_memory_object_stream, ExceptionGroup\n",
    "from unittest.mock import Mock, MagicMock, AsyncMock\n",
    "\n",
    "import asyncer\n",
    "import pytest\n",
    "from aiokafka import ConsumerRecord, TopicPartition\n",
    "from pydantic import BaseModel, Field, HttpUrl, NonNegativeInt\n",
    "from tqdm.notebook import tqdm\n",
    "from types import CoroutineType\n",
    "\n",
    "from fastkafka._components.logger import suppress_timestamps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb6e2212",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "# Module-level logger used by the task pool and the executors defined in this module.\n",
    "logger = get_logger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11c31bb0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] __main__: ok\n"
     ]
    }
   ],
   "source": [
    "# suppress_timestamps() presumably strips timestamps from log lines (the recorded output has none).\n",
    "suppress_timestamps()\n",
    "# Rebind the notebook logger at level 20 (logging.INFO) and emit one record as a smoke test.\n",
    "logger = get_logger(__name__, level=20)\n",
    "logger.info(\"ok\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c8da6a8",
   "metadata": {},
   "source": [
    "## anyio stream is not running tasks in parallel\n",
    "> The memory object stream buffers the messages, but they are consumed one by one: a new message is consumed only after the previous one has finished."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80816012",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9285928200db4a55b0124f898b0e4984",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd84a793afbc42f4ab61ecd78e211c02",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Every coroutine sleeps for `latency` seconds and the consumer awaits them one at a time.\n",
    "num_msgs, latency = 5, 0.2\n",
    "receive_pbar = tqdm(total=2 * num_msgs)\n",
    "\n",
    "\n",
    "async def latency_task():\n",
    "    \"\"\"Tick the progress bar before and after sleeping for `latency` seconds.\"\"\"\n",
    "    receive_pbar.update(1)\n",
    "    await asyncio.sleep(latency)\n",
    "    receive_pbar.update(1)\n",
    "\n",
    "\n",
    "async def process_message_callback(receive_stream) -> None:\n",
    "    \"\"\"Await each coroutine object taken from the stream, strictly in arrival order.\"\"\"\n",
    "    async with receive_stream:\n",
    "        async for pending in receive_stream:\n",
    "            await pending\n",
    "\n",
    "\n",
    "send_stream, receive_stream = anyio.create_memory_object_stream(\n",
    "    max_buffer_size=num_msgs\n",
    ")\n",
    "\n",
    "t0 = datetime.now()\n",
    "async with anyio.create_task_group() as tg:\n",
    "    tg.start_soon(process_message_callback, receive_stream)\n",
    "    async with send_stream:\n",
    "        for _ in tqdm(range(num_msgs)):\n",
    "            await send_stream.send(latency_task())\n",
    "elapsed = datetime.now() - t0\n",
    "\n",
    "# Sequential consumption cannot finish faster than the sum of all latencies.\n",
    "assert elapsed >= timedelta(seconds=latency * num_msgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7a41e0a",
   "metadata": {},
   "source": [
    "To solve this, we can create tasks from the coroutines and let them run in the background, so that the receive_stream keeps spawning new tasks without being blocked by the previous ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "842893fa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b686c64cc60d4a31b12d34ad80a213a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6a76d2fb99644238ac475883f1ad197",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ok\n"
     ]
    }
   ],
   "source": [
    "# Same experiment as above, but the consumer wraps every coroutine in a background task.\n",
    "num_msgs, latency = 10_000, 4.0\n",
    "receive_pbar = tqdm(total=2 * num_msgs)\n",
    "tasks = set()\n",
    "\n",
    "\n",
    "async def latency_task():\n",
    "    \"\"\"Tick the progress bar before and after sleeping for `latency` seconds.\"\"\"\n",
    "    receive_pbar.update(1)\n",
    "    await asyncio.sleep(latency)\n",
    "    receive_pbar.update(1)\n",
    "\n",
    "\n",
    "async def process_message_callback(receive_stream) -> None:\n",
    "    \"\"\"Start a task for every coroutine function received and track it in `tasks`.\"\"\"\n",
    "    async with receive_stream:\n",
    "        async for coro_fn in receive_stream:\n",
    "            task: asyncio.Task = asyncio.create_task(coro_fn())\n",
    "            tasks.add(task)\n",
    "            # The done callback is invoked with the finished task as its only argument.\n",
    "            task.add_done_callback(tasks.remove)\n",
    "\n",
    "\n",
    "send_stream, receive_stream = anyio.create_memory_object_stream(\n",
    "    max_buffer_size=num_msgs\n",
    ")\n",
    "\n",
    "t0 = datetime.now()\n",
    "async with anyio.create_task_group() as tg:\n",
    "    tg.start_soon(process_message_callback, receive_stream)\n",
    "    async with send_stream:\n",
    "        for _ in tqdm(range(num_msgs)):\n",
    "            await send_stream.send(latency_task)\n",
    "\n",
    "# Halfway through the latency every task must have started and none may have finished.\n",
    "await asyncio.sleep(latency / 2)\n",
    "receive_pbar.refresh()\n",
    "assert receive_pbar.n == num_msgs, receive_pbar.n\n",
    "\n",
    "# Yield to the event loop until every done callback has emptied the set.\n",
    "while tasks:\n",
    "    await asyncio.sleep(0)\n",
    "await send_stream.aclose()\n",
    "\n",
    "receive_pbar.close()\n",
    "assert datetime.now() - t0 <= timedelta(seconds=latency + 5.0)\n",
    "assert receive_pbar.n == num_msgs * 2, receive_pbar.n\n",
    "\n",
    "print(\"ok\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "839cceba",
   "metadata": {},
   "source": [
    "## Keeping track of tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b255e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "\n",
    "class TaskPool:\n",
    "    def __init__(\n",
    "        self,\n",
    "        size: int = 100_000,\n",
    "        on_error: Optional[Callable[[BaseException], None]] = None,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Initializes a TaskPool instance.\n",
    "\n",
    "        Args:\n",
    "            size: The size of the task pool. Defaults to 100,000.\n",
    "            on_error: Optional callback function to handle task errors. Defaults to None.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        self.size = size\n",
    "        self.pool: Set[Task] = set()\n",
    "        self.on_error = on_error\n",
    "        self.finished = False\n",
    "\n",
    "    async def add(self, item: Task) -> None:\n",
    "        \"\"\"\n",
    "        Adds a task to the task pool.\n",
    "\n",
    "        Args:\n",
    "            item: The task to be added.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        while len(self.pool) >= self.size:\n",
    "            await asyncio.sleep(0)\n",
    "        self.pool.add(item)\n",
    "        item.add_done_callback(self.discard)\n",
    "\n",
    "    def discard(self, task: Task) -> None:\n",
    "        \"\"\"\n",
    "        Discards a completed task from the task pool.\n",
    "\n",
    "        Args:\n",
    "            task: The completed task to be discarded.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        e = task.exception()\n",
    "        if e is not None and self.on_error is not None:\n",
    "            try:\n",
    "                self.on_error(e)\n",
    "            except Exception as ee:\n",
    "                logger.warning(\n",
    "                    f\"Exception {ee} raised when calling on_error() callback: {e}\"\n",
    "                )\n",
    "\n",
    "        self.pool.discard(task)\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        \"\"\"\n",
    "        Returns the number of tasks in the task pool.\n",
    "\n",
    "        Returns:\n",
    "            The number of tasks in the task pool.\n",
    "        \"\"\"\n",
    "        return len(self.pool)\n",
    "\n",
    "    async def __aenter__(self) -> \"TaskPool\":\n",
    "        self.finished = False\n",
    "        return self\n",
    "\n",
    "    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:\n",
    "        while len(self) > 0:\n",
    "            await asyncio.sleep(0)\n",
    "        self.finished = True\n",
    "\n",
    "    @staticmethod\n",
    "    def log_error(logger: Logger) -> Callable[[Exception], None]:\n",
    "        \"\"\"\n",
    "        Creates a decorator that logs errors using the specified logger.\n",
    "\n",
    "        Args:\n",
    "            logger: The logger to use for error logging.\n",
    "\n",
    "        Returns:\n",
    "            The decorator function.\n",
    "        \"\"\"\n",
    "\n",
    "        def _log_error(e: Exception, logger: Logger = logger) -> None:\n",
    "            logger.warning(f\"{e=}\")\n",
    "\n",
    "        return _log_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce06ffff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: a pool with no tasks has nothing to wait for when the block exits.\n",
    "async with TaskPool() as tp:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0625abf",
   "metadata": {},
   "outputs": [],
   "source": [
    "async def f():\n",
    "    \"\"\"Keep the task alive long enough to observe it inside the pool.\"\"\"\n",
    "    await asyncio.sleep(2)\n",
    "\n",
    "\n",
    "pool = TaskPool()\n",
    "assert len(pool) == 0\n",
    "\n",
    "async with pool:\n",
    "    task = asyncio.create_task(f())\n",
    "    await pool.add(task)\n",
    "    # The task is still sleeping, so it must be tracked.\n",
    "    assert len(pool) == 1\n",
    "\n",
    "# Leaving the context waits for the task, whose done callback removes it.\n",
    "assert len(pool) == 0, len(pool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f931de1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WARNING] __main__: e=RuntimeError('funny error')\n"
     ]
    }
   ],
   "source": [
    "async def f():\n",
    "    \"\"\"Fail immediately so that the pool has an exception to report.\"\"\"\n",
    "    raise RuntimeError(\"funny error\")\n",
    "\n",
    "\n",
    "# The failure is expected to be logged as a warning rather than raised.\n",
    "pool = TaskPool(on_error=TaskPool.log_error(logger))\n",
    "\n",
    "async with pool:\n",
    "    task = asyncio.create_task(f())\n",
    "    await pool.add(task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd66e07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "class ExceptionMonitor:\n",
    "    def __init__(self) -> None:\n",
    "        \"\"\"\n",
    "        Initializes an ExceptionMonitor instance.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        self.exceptions: List[Exception] = []\n",
    "        self.exception_found = False\n",
    "\n",
    "    def on_error(self, e: Exception) -> None:\n",
    "        \"\"\"\n",
    "        Handles an error by storing the exception.\n",
    "\n",
    "        Args:\n",
    "            e: The exception to be handled.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        self.exceptions.append(e)\n",
    "        self.exception_found = True\n",
    "\n",
    "    def _monitor_step(self) -> None:\n",
    "        \"\"\"\n",
    "        Raises the next exception in the queue.\n",
    "\n",
    "        Returns:\n",
    "            None\n",
    "        \"\"\"\n",
    "        if len(self.exceptions) > 0:\n",
    "            e = self.exceptions.pop(0)\n",
    "            raise e\n",
    "\n",
    "    async def __aenter__(self) -> \"ExceptionMonitor\":\n",
    "        return self\n",
    "\n",
    "    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:\n",
    "        while len(self.exceptions) > 0:\n",
    "            self._monitor_step()\n",
    "            await asyncio.sleep(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16135fe2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "e=<ExceptionInfo RuntimeError('very funny error.') tblen=4>\n"
     ]
    }
   ],
   "source": [
    "no_tasks = 1\n",
    "\n",
    "\n",
    "async def f():\n",
    "    \"\"\"Fail immediately so that the monitor has something to re-raise.\"\"\"\n",
    "    raise RuntimeError(f\"very funny error.\")\n",
    "\n",
    "\n",
    "exception_monitor = ExceptionMonitor()\n",
    "pool = TaskPool(on_error=exception_monitor.on_error)\n",
    "\n",
    "\n",
    "async def create_tasks():\n",
    "    \"\"\"Feed failing tasks into the pool until the monitor reports an exception.\"\"\"\n",
    "    for _ in range(no_tasks):\n",
    "        await pool.add(asyncio.create_task(f()))\n",
    "        # Without the pause all tasks get created before any of them throws.\n",
    "        await asyncio.sleep(0.1)\n",
    "        if exception_monitor.exception_found:\n",
    "            break\n",
    "\n",
    "\n",
    "# The monitor re-raises the queued task exception when the outer context exits.\n",
    "with pytest.raises(RuntimeError) as e:\n",
    "    async with exception_monitor, pool, asyncer.create_task_group() as tg:\n",
    "        tg.soonify(create_tasks)()\n",
    "\n",
    "print(f\"{e=}\")\n",
    "assert exception_monitor.exceptions == [], len(exception_monitor.exceptions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d46e1354",
   "metadata": {},
   "source": [
    "# Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d56d47ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "\n",
    "class StreamExecutor(ABC):\n",
    "    \"\"\"Abstract interface of executors that fetch consumer records and process them.\"\"\"\n",
    "\n",
    "    @abstractmethod\n",
    "    async def run(  # type: ignore\n",
    "        self,\n",
    "        *,\n",
    "        is_shutting_down_f: Callable[[], bool],\n",
    "        generator: Callable[[], Awaitable[ConsumerRecord]],\n",
    "        processor: Callable[[ConsumerRecord], Awaitable[None]],\n",
    "    ) -> None:\n",
    "        \"\"\"\n",
    "        Runs the stream executor; every concrete executor must implement this method.\n",
    "\n",
    "        Args:\n",
    "            is_shutting_down_f: Function to check if the executor is shutting down.\n",
    "            generator: Awaitable function for retrieving consumer records. The executors\n",
    "                in this module iterate over the awaited result.\n",
    "            processor: Awaitable function for processing a single consumer record.\n",
    "        \"\"\"\n",
    "        ..."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61503ef7",
   "metadata": {},
   "source": [
    "## Streaming tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ac5c907",
   "metadata": {},
   "outputs": [],
   "source": [
    "mock = Mock()\n",
    "async_mock = asyncer.asyncify(mock)\n",
    "\n",
    "send_stream, receive_stream = create_memory_object_stream()\n",
    "pool = TaskPool()\n",
    "\n",
    "\n",
    "async def process_items(receive_stream):\n",
    "    \"\"\"Run the async mock as a pooled task for every item read from the stream.\"\"\"\n",
    "    async with receive_stream:\n",
    "        async for item in receive_stream:\n",
    "            await pool.add(asyncio.create_task(async_mock(item)))\n",
    "\n",
    "\n",
    "# The pool is entered first and therefore exited last, after the task group is done.\n",
    "async with pool, create_task_group() as tg:\n",
    "    tg.start_soon(process_items, receive_stream)\n",
    "    async with send_stream:\n",
    "        await send_stream.send(f\"hi\")\n",
    "\n",
    "mock.assert_called()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a0b5bd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "\n",
    "def _process_items_task(  # type: ignore\n",
    "    processor: Callable[[ConsumerRecord], Awaitable[None]], task_pool: TaskPool\n",
    ") -> Callable[\n",
    "    [\n",
    "        anyio.streams.memory.MemoryObjectReceiveStream,\n",
    "        Callable[[ConsumerRecord], Awaitable[None]],\n",
    "        bool,\n",
    "    ],\n",
    "    Coroutine[Any, Any, Awaitable[None]],\n",
    "]:\n",
    "    \"\"\"\n",
    "    Builds the stream consumer that processes every received message as a pooled task.\n",
    "\n",
    "    Args:\n",
    "        processor: Awaitable function called with each received message.\n",
    "        task_pool: Pool that keeps track of the created tasks.\n",
    "\n",
    "    Returns:\n",
    "        Coroutine function taking the receive stream to consume.\n",
    "    \"\"\"\n",
    "\n",
    "    async def _process_items_wrapper(  # type: ignore\n",
    "        receive_stream: anyio.streams.memory.MemoryObjectReceiveStream,\n",
    "        processor: Callable[[ConsumerRecord], Awaitable[None]] = processor,\n",
    "        task_pool=task_pool,\n",
    "    ):\n",
    "        async with receive_stream:\n",
    "            async for msg in receive_stream:\n",
    "                # The task starts in the background; add() only waits when the pool is full.\n",
    "                await task_pool.add(asyncio.create_task(processor(msg)))  # type: ignore\n",
    "\n",
    "    return _process_items_wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2711f803",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "\n",
    "@export(\"fastkafka.executors\")\n",
    "class DynamicTaskExecutor(StreamExecutor):\n",
    "    \"\"\"A class that implements a dynamic task executor for processing consumer records.\n",
    "\n",
    "    The DynamicTaskExecutor class extends the StreamExecutor class and provides functionality\n",
    "    for running tasks in parallel using asyncio.Task.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        throw_exceptions: bool = False,\n",
    "        max_buffer_size: int = 100_000,\n",
    "        size: int = 100_000,\n",
    "    ):\n",
    "        \"\"\"Create an instance of DynamicTaskExecutor\n",
    "\n",
    "        Args:\n",
    "            throw_exceptions: Flag indicating whether exceptions should be thrown or logged.\n",
    "                Defaults to False.\n",
    "            max_buffer_size: Maximum buffer size for the memory object stream.\n",
    "                Defaults to 100_000.\n",
    "            size: Size of the task pool. Defaults to 100_000.\n",
    "        \"\"\"\n",
    "        self.throw_exceptions = throw_exceptions\n",
    "        self.max_buffer_size = max_buffer_size\n",
    "        self.exception_monitor = ExceptionMonitor()\n",
    "        # Task failures are either queued for re-raising or only logged.\n",
    "        if throw_exceptions:\n",
    "            on_error = self.exception_monitor.on_error\n",
    "        else:\n",
    "            on_error = TaskPool.log_error(logger)\n",
    "        self.task_pool = TaskPool(on_error=on_error, size=size)  # type: ignore\n",
    "\n",
    "    async def run(  # type: ignore\n",
    "        self,\n",
    "        *,\n",
    "        is_shutting_down_f: Callable[[], bool],\n",
    "        generator: Callable[[], Awaitable[ConsumerRecord]],\n",
    "        processor: Callable[[ConsumerRecord], Awaitable[None]],\n",
    "    ) -> None:\n",
    "        \"\"\"\n",
    "        Runs the dynamic task executor.\n",
    "\n",
    "        Args:\n",
    "            is_shutting_down_f: Function to check if the executor is shutting down.\n",
    "            generator: Generator function for retrieving consumer records.\n",
    "            processor: Processor function for processing consumer records.\n",
    "        \"\"\"\n",
    "        send_stream, receive_stream = anyio.create_memory_object_stream(\n",
    "            max_buffer_size=self.max_buffer_size\n",
    "        )\n",
    "        consume = _process_items_task(processor, self.task_pool)\n",
    "\n",
    "        # Exit order is the reverse of entry: task group, then task pool, then monitor.\n",
    "        async with self.exception_monitor, self.task_pool, anyio.create_task_group() as tg:\n",
    "            tg.start_soon(consume, receive_stream)\n",
    "            async with send_stream:\n",
    "                while not is_shutting_down_f():\n",
    "                    # Stop fetching once a task failure was recorded and must be raised.\n",
    "                    if self.throw_exceptions and self.exception_monitor.exception_found:\n",
    "                        break\n",
    "                    for msg in await generator():\n",
    "                        await send_stream.send(msg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc13156",
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_shutting_down_f(call_count:int = 1) -> Callable[[], bool]:\n",
    "    count = {\"count\": 0}\n",
    "    \n",
    "    def _is_shutting_down_f(count=count, call_count:int = call_count):\n",
    "        if count[\"count\"]>=call_count:\n",
    "            return True\n",
    "        else:\n",
    "            count[\"count\"] = count[\"count\"] + 1\n",
    "            return False\n",
    "        \n",
    "    return _is_shutting_down_f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0385d5ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With the default call_count the predicate is False exactly once.\n",
    "f = is_shutting_down_f()\n",
    "assert [f(), f()] == [False, True]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c65ce4ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "msg\n"
     ]
    }
   ],
   "source": [
    "async def produce():\n",
    "    \"\"\"Return a batch holding a single message.\"\"\"\n",
    "    return [\"msg\"]\n",
    "\n",
    "\n",
    "async def consume(msg):\n",
    "    \"\"\"Print the message it is given.\"\"\"\n",
    "    print(msg)\n",
    "\n",
    "\n",
    "stream = DynamicTaskExecutor()\n",
    "\n",
    "# A single loop iteration: one batch is produced and processed before shutdown.\n",
    "await stream.run(is_shutting_down_f=is_shutting_down_f(), generator=produce, processor=consume)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d82173",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both the generator and the processor must have been awaited after one iteration.\n",
    "stream = DynamicTaskExecutor()\n",
    "mock_consume = AsyncMock(spec=CoroutineType)\n",
    "mock_produce = AsyncMock(spec=CoroutineType, return_value=[\"msg\"])\n",
    "\n",
    "await stream.run(\n",
    "    processor=mock_consume,\n",
    "    generator=mock_produce,\n",
    "    is_shutting_down_f=is_shutting_down_f(),\n",
    ")\n",
    "\n",
    "mock_produce.assert_awaited()\n",
    "mock_consume.assert_awaited_with(\"msg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7494a22c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same scenario as the previous cell, checked through the call (not await) assertions.\n",
    "stream = DynamicTaskExecutor()\n",
    "mock_consume = AsyncMock(spec=CoroutineType)\n",
    "mock_produce = AsyncMock(spec=CoroutineType, return_value=[\"msg\"])\n",
    "\n",
    "await stream.run(\n",
    "    processor=mock_consume,\n",
    "    generator=mock_produce,\n",
    "    is_shutting_down_f=is_shutting_down_f(),\n",
    ")\n",
    "\n",
    "mock_produce.assert_called()\n",
    "mock_consume.assert_called_with(\"msg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "128ac8f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_msgs = 13\n",
    "\n",
    "# Every processed message fails; with throw_exceptions=True the failure must propagate.\n",
    "mock_produce = AsyncMock(spec=CoroutineType, return_value=[\"msg\"])\n",
    "mock_consume = AsyncMock(spec=CoroutineType, side_effect=RuntimeError())\n",
    "\n",
    "stream = DynamicTaskExecutor(throw_exceptions=True)\n",
    "\n",
    "with pytest.raises(RuntimeError) as e:\n",
    "    await stream.run(\n",
    "        processor=mock_consume,\n",
    "        generator=mock_produce,\n",
    "        is_shutting_down_f=is_shutting_down_f(num_msgs),\n",
    "    )\n",
    "\n",
    "mock_produce.assert_called()\n",
    "mock_consume.assert_awaited_with(\"msg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd855e8a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n",
      "[WARNING] __main__: e=RuntimeError()\n"
     ]
    }
   ],
   "source": [
    "num_msgs = 13\n",
    "\n",
    "# Every processed message fails; by default the failures are only logged as warnings.\n",
    "mock_produce = AsyncMock(spec=CoroutineType, return_value=[\"msg\"])\n",
    "mock_consume = AsyncMock(spec=CoroutineType, side_effect=RuntimeError())\n",
    "\n",
    "stream = DynamicTaskExecutor()\n",
    "\n",
    "await stream.run(\n",
    "    processor=mock_consume,\n",
    "    generator=mock_produce,\n",
    "    is_shutting_down_f=is_shutting_down_f(num_msgs),\n",
    ")\n",
    "\n",
    "mock_produce.assert_called()\n",
    "mock_consume.assert_awaited_with(\"msg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "276db314",
   "metadata": {},
   "source": [
    "## Awaiting coroutines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d595bbb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "\n",
    "def _process_items_coro(  # type: ignore\n",
    "    processor: Callable[[ConsumerRecord], Awaitable[None]],\n",
    "    throw_exceptions: bool,\n",
    ") -> Callable[\n",
    "    [\n",
    "        anyio.streams.memory.MemoryObjectReceiveStream,\n",
    "        Callable[[ConsumerRecord], Awaitable[None]],\n",
    "        bool,\n",
    "    ],\n",
    "    Coroutine[Any, Any, Awaitable[None]],\n",
    "]:\n",
    "    \"\"\"\n",
    "    Builds the stream consumer that awaits the processor for one message at a time.\n",
    "\n",
    "    Args:\n",
    "        processor: Awaitable function called with each received message.\n",
    "        throw_exceptions: Whether a processor exception is re-raised (True) or\n",
    "            logged as a warning (False).\n",
    "\n",
    "    Returns:\n",
    "        Coroutine function taking the receive stream to consume.\n",
    "    \"\"\"\n",
    "\n",
    "    async def _process_items_wrapper(  # type: ignore\n",
    "        receive_stream: anyio.streams.memory.MemoryObjectReceiveStream,\n",
    "        processor: Callable[[ConsumerRecord], Awaitable[None]] = processor,\n",
    "        throw_exceptions: bool = throw_exceptions,\n",
    "    ) -> Awaitable[None]:\n",
    "        async with receive_stream:\n",
    "            async for msg in receive_stream:\n",
    "                try:\n",
    "                    await processor(msg)\n",
    "                except Exception as e:\n",
    "                    if not throw_exceptions:\n",
    "                        # Log and move on to the next message.\n",
    "                        logger.warning(f\"{e=}\")\n",
    "                        continue\n",
    "                    raise e\n",
    "\n",
    "    return _process_items_wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a81d03b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "\n",
    "@export(\"fastkafka.executors\")\n",
    "class SequentialExecutor(StreamExecutor):\n",
    "    \"\"\"A class that implements a sequential executor for processing consumer records.\n",
    "\n",
    "    The SequentialExecutor class extends the StreamExecutor class and provides functionality\n",
    "    for running processing tasks in sequence by awaiting their coroutines.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        throw_exceptions: bool = False,\n",
    "        max_buffer_size: int = 100_000,\n",
    "    ):\n",
    "        \"\"\"Create an instance of SequentialExecutor\n",
    "\n",
    "        Args:\n",
    "            throw_exceptions: Flag indicating whether exceptions should be thrown or logged.\n",
    "                Defaults to False.\n",
    "            max_buffer_size: Maximum buffer size for the memory object stream.\n",
    "                Defaults to 100_000.\n",
    "        \"\"\"\n",
    "        self.max_buffer_size = max_buffer_size\n",
    "        self.throw_exceptions = throw_exceptions\n",
    "\n",
    "    async def run(  # type: ignore\n",
    "        self,\n",
    "        *,\n",
    "        is_shutting_down_f: Callable[[], bool],\n",
    "        generator: Callable[[], Awaitable[ConsumerRecord]],\n",
    "        processor: Callable[[ConsumerRecord], Awaitable[None]],\n",
    "    ) -> None:\n",
    "        \"\"\"\n",
    "        Runs the sequential executor.\n",
    "\n",
    "        Args:\n",
    "            is_shutting_down_f: Function to check if the executor is shutting down.\n",
    "            generator: Generator function for retrieving consumer records.\n",
    "            processor: Processor function for processing consumer records.\n",
    "        \"\"\"\n",
    "        send_stream, receive_stream = anyio.create_memory_object_stream(\n",
    "            max_buffer_size=self.max_buffer_size\n",
    "        )\n",
    "        consume = _process_items_coro(processor, self.throw_exceptions)\n",
    "\n",
    "        async with anyio.create_task_group() as tg:\n",
    "            tg.start_soon(consume, receive_stream)\n",
    "            # Closing the send stream on exit lets the consumer loop finish.\n",
    "            async with send_stream:\n",
    "                while not is_shutting_down_f():\n",
    "                    for msg in await generator():\n",
    "                        await send_stream.send(msg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fa058a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_msgs = 13\n",
    "\n",
    "# The processor always fails; with throw_exceptions=True the task group surfaces the error.\n",
    "mock_produce = AsyncMock(spec=CoroutineType, return_value=[\"msg\"])\n",
    "mock_consume = AsyncMock(spec=CoroutineType, side_effect=RuntimeError(\"Funny error\"))\n",
    "\n",
    "stream = SequentialExecutor(throw_exceptions=True)\n",
    "\n",
    "with pytest.raises(ExceptionGroup) as e:\n",
    "    await stream.run(\n",
    "        processor=mock_consume,\n",
    "        generator=mock_produce,\n",
    "        is_shutting_down_f=is_shutting_down_f(num_msgs),\n",
    "    )\n",
    "\n",
    "mock_produce.assert_called()\n",
    "mock_consume.assert_awaited_with(\"msg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9fa60fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n",
      "[WARNING] __main__: e=RuntimeError('Funny error')\n"
     ]
    }
   ],
   "source": [
    "num_msgs = 13\n",
    "\n",
    "# The processor always fails; by default each failure is only logged as a warning.\n",
    "mock_produce = AsyncMock(spec=CoroutineType, return_value=[\"msg\"])\n",
    "mock_consume = AsyncMock(spec=CoroutineType, side_effect=RuntimeError(\"Funny error\"))\n",
    "\n",
    "stream = SequentialExecutor()\n",
    "\n",
    "await stream.run(\n",
    "    processor=mock_consume,\n",
    "    generator=mock_produce,\n",
    "    is_shutting_down_f=is_shutting_down_f(num_msgs),\n",
    ")\n",
    "\n",
    "mock_produce.assert_called()\n",
    "mock_consume.assert_awaited_with(\"msg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc03222",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | export\n",
    "\n",
    "\n",
    "def get_executor(executor: Union[str, StreamExecutor, None] = None) -> StreamExecutor:\n",
    "    \"\"\"\n",
    "    Returns an instance of the specified executor.\n",
    "\n",
    "    Args:\n",
    "        executor: Executor instance, name of an executor class defined in\n",
    "            fastkafka._components.task_streaming, or None for the default\n",
    "            SequentialExecutor.\n",
    "\n",
    "    Returns:\n",
    "        The given instance unchanged, or a new instance of the named executor\n",
    "        class created with its default arguments.\n",
    "\n",
    "    Raises:\n",
    "        AttributeError: If the executor is not found, or if the name does not\n",
    "            refer to a StreamExecutor subclass.\n",
    "    \"\"\"\n",
    "    if isinstance(executor, StreamExecutor):\n",
    "        return executor\n",
    "    if executor is None:\n",
    "        executor = \"SequentialExecutor\"\n",
    "\n",
    "    module = sys.modules[\"fastkafka._components.task_streaming\"]\n",
    "    executor_cls = getattr(module, executor)\n",
    "    # Validate against the base class of the module the name was resolved in, so the\n",
    "    # check also holds when this function runs from a copy of the module (e.g. a notebook).\n",
    "    base_cls = getattr(module, \"StreamExecutor\", StreamExecutor)\n",
    "    if not (isinstance(executor_cls, type) and issubclass(executor_cls, base_cls)):\n",
    "        # Any other module attribute (helper function, logger, ...) is not an executor.\n",
    "        raise AttributeError(\n",
    "            f\"'{executor}' is not a stream executor defined in {module.__name__}\"\n",
    "        )\n",
    "    return executor_cls()  # type: ignore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab892f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# None, the class name and an existing instance must all yield a SequentialExecutor.\n",
    "for executor in (None, \"SequentialExecutor\", SequentialExecutor()):\n",
    "    actual = get_executor(executor)\n",
    "    assert type(actual).__qualname__ == \"SequentialExecutor\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bba3526",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The class name and an existing instance must both yield a DynamicTaskExecutor.\n",
    "for executor in (\"DynamicTaskExecutor\", DynamicTaskExecutor()):\n",
    "    actual = get_executor(executor)\n",
    "    assert type(actual).__qualname__ == \"DynamicTaskExecutor\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
